{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "16b37a19-0509-4842-beda-899b3c1bf9ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "60000\n",
      "(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0510,\n",
      "          0.2863, 0.0000, 0.0000, 0.0039, 0.0157, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0039, 0.0039, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.1412, 0.5333,\n",
      "          0.4980, 0.2431, 0.2118, 0.0000, 0.0000, 0.0000, 0.0039, 0.0118,\n",
      "          0.0157, 0.0000, 0.0000, 0.0118],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0000, 0.4000, 0.8000,\n",
      "          0.6902, 0.5255, 0.5647, 0.4824, 0.0902, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0471, 0.0392, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6078, 0.9255,\n",
      "          0.8118, 0.6980, 0.4196, 0.6118, 0.6314, 0.4275, 0.2510, 0.0902,\n",
      "          0.3020, 0.5098, 0.2824, 0.0588],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.2706, 0.8118, 0.8745,\n",
      "          0.8549, 0.8471, 0.8471, 0.6392, 0.4980, 0.4745, 0.4784, 0.5725,\n",
      "          0.5529, 0.3451, 0.6745, 0.2588],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0039, 0.0039, 0.0039, 0.0000, 0.7843, 0.9098, 0.9098,\n",
      "          0.9137, 0.8980, 0.8745, 0.8745, 0.8431, 0.8353, 0.6431, 0.4980,\n",
      "          0.4824, 0.7686, 0.8980, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7176, 0.8824, 0.8471,\n",
      "          0.8745, 0.8941, 0.9216, 0.8902, 0.8784, 0.8706, 0.8784, 0.8667,\n",
      "          0.8745, 0.9608, 0.6784, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7569, 0.8941, 0.8549,\n",
      "          0.8353, 0.7765, 0.7059, 0.8314, 0.8235, 0.8275, 0.8353, 0.8745,\n",
      "          0.8627, 0.9529, 0.7922, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0039, 0.0118, 0.0000, 0.0471, 0.8588, 0.8627, 0.8314,\n",
      "          0.8549, 0.7529, 0.6627, 0.8902, 0.8157, 0.8549, 0.8784, 0.8314,\n",
      "          0.8863, 0.7725, 0.8196, 0.2039],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0235, 0.0000, 0.3882, 0.9569, 0.8706, 0.8627,\n",
      "          0.8549, 0.7961, 0.7765, 0.8667, 0.8431, 0.8353, 0.8706, 0.8627,\n",
      "          0.9608, 0.4667, 0.6549, 0.2196],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0157, 0.0000, 0.0000, 0.2157, 0.9255, 0.8941, 0.9020,\n",
      "          0.8941, 0.9412, 0.9098, 0.8353, 0.8549, 0.8745, 0.9176, 0.8510,\n",
      "          0.8510, 0.8196, 0.3608, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0039, 0.0157, 0.0235, 0.0275, 0.0078, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.9294, 0.8863, 0.8510, 0.8745,\n",
      "          0.8706, 0.8588, 0.8706, 0.8667, 0.8471, 0.8745, 0.8980, 0.8431,\n",
      "          0.8549, 1.0000, 0.3020, 0.0000],\n",
      "         [0.0000, 0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.2431, 0.5686, 0.8000, 0.8941, 0.8118, 0.8353, 0.8667,\n",
      "          0.8549, 0.8157, 0.8275, 0.8549, 0.8784, 0.8745, 0.8588, 0.8431,\n",
      "          0.8784, 0.9569, 0.6235, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.1725, 0.3216, 0.4196,\n",
      "          0.7412, 0.8941, 0.8627, 0.8706, 0.8510, 0.8863, 0.7843, 0.8039,\n",
      "          0.8275, 0.9020, 0.8784, 0.9176, 0.6902, 0.7373, 0.9804, 0.9725,\n",
      "          0.9137, 0.9333, 0.8431, 0.0000],\n",
      "         [0.0000, 0.2235, 0.7333, 0.8157, 0.8784, 0.8667, 0.8784, 0.8157,\n",
      "          0.8000, 0.8392, 0.8157, 0.8196, 0.7843, 0.6235, 0.9608, 0.7569,\n",
      "          0.8078, 0.8745, 1.0000, 1.0000, 0.8667, 0.9176, 0.8667, 0.8275,\n",
      "          0.8627, 0.9098, 0.9647, 0.0000],\n",
      "         [0.0118, 0.7922, 0.8941, 0.8784, 0.8667, 0.8275, 0.8275, 0.8392,\n",
      "          0.8039, 0.8039, 0.8039, 0.8627, 0.9412, 0.3137, 0.5882, 1.0000,\n",
      "          0.8980, 0.8667, 0.7373, 0.6039, 0.7490, 0.8235, 0.8000, 0.8196,\n",
      "          0.8706, 0.8941, 0.8824, 0.0000],\n",
      "         [0.3843, 0.9137, 0.7765, 0.8235, 0.8706, 0.8980, 0.8980, 0.9176,\n",
      "          0.9765, 0.8627, 0.7608, 0.8431, 0.8510, 0.9451, 0.2549, 0.2863,\n",
      "          0.4157, 0.4588, 0.6588, 0.8588, 0.8667, 0.8431, 0.8510, 0.8745,\n",
      "          0.8745, 0.8784, 0.8980, 0.1137],\n",
      "         [0.2941, 0.8000, 0.8314, 0.8000, 0.7569, 0.8039, 0.8275, 0.8824,\n",
      "          0.8471, 0.7255, 0.7725, 0.8078, 0.7765, 0.8353, 0.9412, 0.7647,\n",
      "          0.8902, 0.9608, 0.9373, 0.8745, 0.8549, 0.8314, 0.8196, 0.8706,\n",
      "          0.8627, 0.8667, 0.9020, 0.2627],\n",
      "         [0.1882, 0.7961, 0.7176, 0.7608, 0.8353, 0.7725, 0.7255, 0.7451,\n",
      "          0.7608, 0.7529, 0.7922, 0.8392, 0.8588, 0.8667, 0.8627, 0.9255,\n",
      "          0.8824, 0.8471, 0.7804, 0.8078, 0.7294, 0.7098, 0.6941, 0.6745,\n",
      "          0.7098, 0.8039, 0.8078, 0.4510],\n",
      "         [0.0000, 0.4784, 0.8588, 0.7569, 0.7020, 0.6706, 0.7176, 0.7686,\n",
      "          0.8000, 0.8235, 0.8353, 0.8118, 0.8275, 0.8235, 0.7843, 0.7686,\n",
      "          0.7608, 0.7490, 0.7647, 0.7490, 0.7765, 0.7529, 0.6902, 0.6118,\n",
      "          0.6549, 0.6941, 0.8235, 0.3608],\n",
      "         [0.0000, 0.0000, 0.2902, 0.7412, 0.8314, 0.7490, 0.6863, 0.6745,\n",
      "          0.6863, 0.7098, 0.7255, 0.7373, 0.7412, 0.7373, 0.7569, 0.7765,\n",
      "          0.8000, 0.8196, 0.8235, 0.8235, 0.8275, 0.7373, 0.7373, 0.7608,\n",
      "          0.7529, 0.8471, 0.6667, 0.0000],\n",
      "         [0.0078, 0.0000, 0.0000, 0.0000, 0.2588, 0.7843, 0.8706, 0.9294,\n",
      "          0.9373, 0.9490, 0.9647, 0.9529, 0.9569, 0.8667, 0.8627, 0.7569,\n",
      "          0.7490, 0.7020, 0.7137, 0.7137, 0.7098, 0.6902, 0.6510, 0.6588,\n",
      "          0.3882, 0.2275, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1569,\n",
      "          0.2392, 0.1725, 0.2824, 0.1608, 0.1373, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000]]]), 9)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor, Lambda, Compose\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Download training data from open datasets.\n",
    "training_data = datasets.FashionMNIST(\n",
    "    root=\"data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=ToTensor(),\n",
    ")\n",
    "print(len(training_data))\n",
    "print(training_data.__getitem__(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "86a778cf-8648-43cf-bdce-50f36b07f862",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download test data from open datasets.\n",
    "test_data = datasets.FashionMNIST(\n",
    "    root=\"data\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=ToTensor(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dfc62d43-bd7e-4bdb-8016-16d8402a5f37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of X [N, C, H, W]:  torch.Size([64, 1, 28, 28])\n",
      "Shape of y:  torch.Size([64]) torch.int64\n"
     ]
    }
   ],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# Create data loaders.\n",
    "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n",
    "test_dataloader = DataLoader(test_data, batch_size=batch_size)\n",
    "\n",
    "for X, y in test_dataloader:\n",
    "    print(\"Shape of X [N, C, H, W]: \", X.shape)\n",
    "    print(\"Shape of y: \", y.shape, y.dtype)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2aaa1bf1-a258-4968-9895-2baced7b65e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "NeuralNetwork(\n",
      "  (flatten): Flatten(start_dim=1, end_dim=-1)\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=784, out_features=512, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=512, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Get cpu or gpu device for training.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using {device} device\")\n",
    "\n",
    "# Define model\n",
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NeuralNetwork, self).__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 10)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.flatten(x)\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        return logits\n",
    "\n",
    "model = NeuralNetwork().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af800efc-ac4f-4c4e-b28a-39b70f23dbe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "134633b3-f9af-4f2c-99cb-fcdf4830fba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataloader, model, loss_fn, optimizer):\n",
    "    size = len(dataloader.dataset)\n",
    "    model.train()\n",
    "    for batch, (X, y) in enumerate(dataloader):\n",
    "        X, y = X.to(device), y.to(device)\n",
    "\n",
    "        # Compute prediction error\n",
    "        pred = model(X)\n",
    "        loss = loss_fn(pred, y)\n",
    "\n",
    "        # Backpropagation\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if batch % 100 == 0:\n",
    "            loss, current = loss.item(), batch * len(X)\n",
    "            print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "11bc9655-67bb-4728-965b-680d2cd63310",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(dataloader, model, loss_fn):\n",
    "    size = len(dataloader.dataset)\n",
    "    num_batches = len(dataloader)\n",
    "    model.eval()\n",
    "    test_loss, correct = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for X, y in dataloader:\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            pred = model(X)\n",
    "            test_loss += loss_fn(pred, y).item()\n",
    "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
    "    test_loss /= num_batches\n",
    "    correct /= size\n",
    "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d260203c-d6e2-4ade-ba1d-70efe9b12225",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n",
      "loss: 2.304709  [    0/60000]\n",
      "loss: 2.293681  [ 6400/60000]\n",
      "loss: 2.264333  [12800/60000]\n",
      "loss: 2.253231  [19200/60000]\n",
      "loss: 2.249592  [25600/60000]\n",
      "loss: 2.213964  [32000/60000]\n",
      "loss: 2.219313  [38400/60000]\n",
      "loss: 2.184856  [44800/60000]\n",
      "loss: 2.184898  [51200/60000]\n",
      "loss: 2.142305  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 42.3%, Avg loss: 2.143535 \n",
      "\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 2.149639  [    0/60000]\n",
      "loss: 2.145669  [ 6400/60000]\n",
      "loss: 2.077938  [12800/60000]\n",
      "loss: 2.094888  [19200/60000]\n",
      "loss: 2.050446  [25600/60000]\n",
      "loss: 1.984101  [32000/60000]\n",
      "loss: 2.004893  [38400/60000]\n",
      "loss: 1.926185  [44800/60000]\n",
      "loss: 1.935143  [51200/60000]\n",
      "loss: 1.850987  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 59.1%, Avg loss: 1.858153 \n",
      "\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 1.885805  [    0/60000]\n",
      "loss: 1.862879  [ 6400/60000]\n",
      "loss: 1.739061  [12800/60000]\n",
      "loss: 1.781191  [19200/60000]\n",
      "loss: 1.676659  [25600/60000]\n",
      "loss: 1.632570  [32000/60000]\n",
      "loss: 1.641042  [38400/60000]\n",
      "loss: 1.555164  [44800/60000]\n",
      "loss: 1.581083  [51200/60000]\n",
      "loss: 1.469921  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 61.8%, Avg loss: 1.496637 \n",
      "\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 1.557977  [    0/60000]\n",
      "loss: 1.533555  [ 6400/60000]\n",
      "loss: 1.381828  [12800/60000]\n",
      "loss: 1.452681  [19200/60000]\n",
      "loss: 1.339994  [25600/60000]\n",
      "loss: 1.341980  [32000/60000]\n",
      "loss: 1.345259  [38400/60000]\n",
      "loss: 1.281594  [44800/60000]\n",
      "loss: 1.315619  [51200/60000]\n",
      "loss: 1.215707  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 63.0%, Avg loss: 1.243245 \n",
      "\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 1.314489  [    0/60000]\n",
      "loss: 1.304992  [ 6400/60000]\n",
      "loss: 1.137102  [12800/60000]\n",
      "loss: 1.242498  [19200/60000]\n",
      "loss: 1.122548  [25600/60000]\n",
      "loss: 1.150157  [32000/60000]\n",
      "loss: 1.164254  [38400/60000]\n",
      "loss: 1.109600  [44800/60000]\n",
      "loss: 1.148404  [51200/60000]\n",
      "loss: 1.064258  [57600/60000]\n",
      "Test Error: \n",
      " Accuracy: 64.5%, Avg loss: 1.085036 \n",
      "\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "epochs = 5\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train(train_dataloader, model, loss_fn, optimizer)\n",
    "    test(test_dataloader, model, loss_fn)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "233d28dd-e8ff-484f-839f-72597a3377d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved PyTorch Model State to model.pth\n"
     ]
    }
   ],
   "source": [
    "torch.save(model.state_dict(), \"model.pth\")\n",
    "print(\"Saved PyTorch Model State to model.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a63a6958-0b7a-4b1e-9f22-492ceb70d789",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = NeuralNetwork()\n",
    "model.load_state_dict(torch.load(\"model.pth\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f0eef92c-db9d-47bb-bb44-7fadb081d53f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0118, 0.0039, 0.0000, 0.0000, 0.0275,\n",
      "          0.0000, 0.1451, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0078, 0.0000,\n",
      "          0.1059, 0.3294, 0.0431, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.4667, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,\n",
      "          0.3451, 0.5608, 0.4314, 0.0000, 0.0000, 0.0000, 0.0000, 0.0863,\n",
      "          0.3647, 0.4157, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0157, 0.0000, 0.2078,\n",
      "          0.5059, 0.4706, 0.5765, 0.6863, 0.6157, 0.6510, 0.5294, 0.6039,\n",
      "          0.6588, 0.5490, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0431, 0.5373,\n",
      "          0.5098, 0.5020, 0.6275, 0.6902, 0.6235, 0.6549, 0.6980, 0.5843,\n",
      "          0.5922, 0.5647, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000,\n",
      "          0.0078, 0.0039, 0.0000, 0.0118, 0.0000, 0.0000, 0.4510, 0.4471,\n",
      "          0.4157, 0.5373, 0.6588, 0.6000, 0.6118, 0.6471, 0.6549, 0.5608,\n",
      "          0.6157, 0.6196, 0.0431, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0118, 0.0000, 0.0000, 0.3490, 0.5451, 0.3529,\n",
      "          0.3686, 0.6000, 0.5843, 0.5137, 0.5922, 0.6627, 0.6745, 0.5608,\n",
      "          0.6235, 0.6627, 0.1882, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0157,\n",
      "          0.0039, 0.0000, 0.0000, 0.0000, 0.3843, 0.5333, 0.4314, 0.4275,\n",
      "          0.4314, 0.6353, 0.5294, 0.5647, 0.5843, 0.6235, 0.6549, 0.5647,\n",
      "          0.6196, 0.6627, 0.4667, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0078, 0.0078, 0.0039, 0.0078, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.1020, 0.4235, 0.4588, 0.3882, 0.4353, 0.4588,\n",
      "          0.5333, 0.6118, 0.5255, 0.6039, 0.6039, 0.6118, 0.6275, 0.5529,\n",
      "          0.5765, 0.6118, 0.6980, 0.0000],\n",
      "         [0.0118, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0824,\n",
      "          0.2078, 0.3608, 0.4588, 0.4353, 0.4039, 0.4510, 0.5059, 0.5255,\n",
      "          0.5608, 0.6039, 0.6471, 0.6667, 0.6039, 0.5922, 0.6039, 0.5608,\n",
      "          0.5412, 0.5882, 0.6471, 0.1686],\n",
      "         [0.0000, 0.0000, 0.0902, 0.2118, 0.2549, 0.2980, 0.3333, 0.4627,\n",
      "          0.5020, 0.4824, 0.4353, 0.4431, 0.4627, 0.4980, 0.4902, 0.5451,\n",
      "          0.5216, 0.5333, 0.6275, 0.5490, 0.6078, 0.6314, 0.5647, 0.6078,\n",
      "          0.6745, 0.6314, 0.7412, 0.2431],\n",
      "         [0.0000, 0.2667, 0.3686, 0.3529, 0.4353, 0.4471, 0.4353, 0.4471,\n",
      "          0.4510, 0.4980, 0.5294, 0.5333, 0.5608, 0.4941, 0.4980, 0.5922,\n",
      "          0.6039, 0.5608, 0.5804, 0.4902, 0.6353, 0.6353, 0.5647, 0.5412,\n",
      "          0.6000, 0.6353, 0.7686, 0.2275],\n",
      "         [0.2745, 0.6627, 0.5059, 0.4078, 0.3843, 0.3922, 0.3686, 0.3804,\n",
      "          0.3843, 0.4000, 0.4235, 0.4157, 0.4667, 0.4706, 0.5059, 0.5843,\n",
      "          0.6118, 0.6549, 0.7451, 0.7451, 0.7686, 0.7765, 0.7765, 0.7333,\n",
      "          0.7725, 0.7412, 0.7216, 0.1412],\n",
      "         [0.0627, 0.4941, 0.6706, 0.7373, 0.7373, 0.7216, 0.6706, 0.6000,\n",
      "          0.5294, 0.4706, 0.4941, 0.4980, 0.5725, 0.7255, 0.7647, 0.8196,\n",
      "          0.8157, 1.0000, 0.8196, 0.6941, 0.9608, 0.9882, 0.9843, 0.9843,\n",
      "          0.9686, 0.8627, 0.8078, 0.1922],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0471, 0.2627, 0.4157, 0.6431, 0.7255,\n",
      "          0.7804, 0.8235, 0.8275, 0.8235, 0.8157, 0.7451, 0.5882, 0.3216,\n",
      "          0.0314, 0.0000, 0.0000, 0.0000, 0.6980, 0.8157, 0.7373, 0.6863,\n",
      "          0.6353, 0.6196, 0.5922, 0.0431],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000],\n",
      "         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "          0.0000, 0.0000, 0.0000, 0.0000]]])\n",
      "9\n",
      "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n"
     ]
    }
   ],
   "source": [
    "classes = [\n",
    "    \"T-shirt/top\",\n",
    "    \"Trouser\",\n",
    "    \"Pullover\",\n",
    "    \"Dress\",\n",
    "    \"Coat\",\n",
    "    \"Sandal\",\n",
    "    \"Shirt\",\n",
    "    \"Sneaker\",\n",
    "    \"Bag\",\n",
    "    \"Ankle boot\",\n",
    "]\n",
    "\n",
    "model.eval()\n",
    "x, y = test_data[0][0], test_data[0][1]\n",
    "print(x)\n",
    "print(y)\n",
    "with torch.no_grad():\n",
    "    pred = model(x)\n",
    "    predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
    "    print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98ece637-669e-4eb5-b372-ef20b090c451",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
